{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "adf933f7",
   "metadata": {},
   "source": [
    "# L2: Image captioning app 🖼️📝"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b8996b4",
   "metadata": {},
   "source": [
    "Load your HF API key and relevant Python libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3471d1ed-41a0-473c-b3c7-99a9e14dffaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import io\n",
    "import IPython.display\n",
    "from PIL import Image\n",
    "import base64 \n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "hf_api_key = os.environ['HF_API_KEY']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05896028-3b43-408e-a899-109bde9625e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper functions\n",
    "import requests, json\n",
    "\n",
    "#Image-to-text endpoint\n",
    "def get_completion(inputs, parameters=None, ENDPOINT_URL=os.environ['HF_API_ITT_BASE']):\n",
    "    headers = {\n",
    "      \"Authorization\": f\"Bearer {hf_api_key}\",\n",
    "      \"Content-Type\": \"application/json\"\n",
    "    }\n",
    "    data = { \"inputs\": inputs }\n",
    "    if parameters is not None:\n",
    "        data.update({\"parameters\": parameters})\n",
    "    response = requests.request(\"POST\",\n",
    "                                ENDPOINT_URL,\n",
    "                                headers=headers,\n",
    "                                data=json.dumps(data))\n",
    "    return json.loads(response.content.decode(\"utf-8\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2541e21e",
   "metadata": {},
   "source": [
    "## Building an image captioning app "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1ca462d",
   "metadata": {},
   "source": [
    "Here we'll be using an [Inference Endpoint](https://huggingface.co/inference-endpoints) for `Salesforce/blip-image-captioning-base` a 14M parameter captioning model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b98cf3f7",
   "metadata": {},
   "source": [
    "The free images are available on: https://free-images.com/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb7154b9-fa1c-416e-8ee2-635c27720539",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_url = \"https://free-images.com/sm/9596/dog_animal_greyhound_983023.jpg\"\n",
    "display(IPython.display.Image(url=image_url))\n",
    "get_completion(image_url)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90da616d",
   "metadata": {},
   "source": [
    "## Captioning with `gr.Interface()`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3bee33c-5e85-4677-b25a-fcee2b36703a",
   "metadata": {},
   "source": [
    "#### gr.Image()\n",
    "- The `type` parameter is the format that the `fn` function expects to receive as its input.  If `type` is `numpy` or `pil`, `gr.Image()` will convert the uploaded file to this format before sending it to the `fn` function.\n",
    "- If `type` is `filepath`, `gr.Image()` will temporarily store the image and provide a string path to that image location as input to the `fn` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8021fac0-9300-44d5-adb1-6ac7111f38f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr \n",
    "\n",
    "def image_to_base64_str(pil_image):\n",
    "    byte_arr = io.BytesIO()\n",
    "    pil_image.save(byte_arr, format='PNG')\n",
    "    byte_arr = byte_arr.getvalue()\n",
    "    return str(base64.b64encode(byte_arr).decode('utf-8'))\n",
    "\n",
    "def captioner(image):\n",
    "    base64_image = image_to_base64_str(image)\n",
    "    result = get_completion(base64_image)\n",
    "    return result[0]['generated_text']\n",
    "\n",
    "gr.close_all()\n",
    "demo = gr.Interface(fn=captioner,\n",
    "                    inputs=[gr.Image(label=\"Upload image\", type=\"pil\")],\n",
    "                    outputs=[gr.Textbox(label=\"Caption\")],\n",
    "                    title=\"Image Captioning with BLIP\",\n",
    "                    description=\"Caption any image using the BLIP model\",\n",
    "                    allow_flagging=\"never\",\n",
    "                    examples=[\"christmas_dog.jpeg\", \"bird_flight.jpeg\", \"cow.jpeg\"])\n",
    "\n",
    "demo.launch(share=True, server_port=int(os.environ['PORT1']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "505950a7-c516-4dcc-bb50-a0628e7b57b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "gr.close_all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43cf5022",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
